import os

import matplotlib.pyplot as plt
import numpy as np
from util import *

from alg_stochastic import *


batch_size = 128
mini_batch_size = 10

#############################################################################
iters_sgd = 800  # sgd
iters_scsg_rand = 10000
iters_scsg_coord = 6000
epoch_rand = int(iters_scsg_rand / (batch_size / mini_batch_size))  # scsg-rand
epoch_coord = int(iters_scsg_coord / (batch_size / mini_batch_size))   # scsg-coord

iters_spider = 1000
epoch_spider = int(iters_spider / (batch_size / mini_batch_size))  # spider
epoch_spider_size = 20  # spider

batch_g_scrn = batch_size       # scrn
batch_h_scrn = mini_batch_size  # scrn
iters_scrn = 5000

iters_zpsgd = 1000
#############################################################################

p = 0.01

L = 100      # tuning
rho = 1       # tuning

data_w1a = np.load('data_w1a.npz')
x, y = data_w1a['x'], data_w1a['y']
num = len(x)

# print(x)
# print(y)

lamda, alpha = 1, 1

f = construct_f_stochastic(x, y, lamda, alpha)


w_0 = list(np.zeros(len(x[0])))
print(w_0)

np.random.seed(10)


# fqc_sgd, zo_sgd_ncf_vals = zo_sgd_ncf(f, num, batch_size, w_0, iters_sgd, p, L, rho)
# np.savez('zo_sgd_ncf.npz', fqc_sgd=fqc_sgd, zo_sgd_ncf_vals=zo_sgd_ncf_vals)

# fqc_scsg_rand, zo_scsg_ncf_rand_vals = zo_scsg_ncf_rand(f, num, batch_size, mini_batch_size, w_0, epoch_rand, p, L, rho)
# np.savez('zo_scsg_ncf_rand.npz', fqc_scsg_rand=fqc_scsg_rand, zo_scsg_ncf_rand_vals=zo_scsg_ncf_rand_vals)
#
# fqc_scsg_coord, zo_scsg_ncf_coord_vals = zo_scsg_ncf_coord(f, num, batch_size, mini_batch_size, x_0, epoch_coord, p, L, rho)
# np.savez('zo_scsg_ncf_coord.npz', fqc_scsg_coord=fqc_scsg_coord, zo_scsg_ncf_coord_vals=zo_scsg_ncf_coord_vals)
#
# fqc_spider, zo_spider_ncf_vals = zo_spider_ncf(f, num, batch_size, mini_batch_size, x_0, epoch_spider, epoch_spider_size, p, L, rho)
# np.savez('zo_spider_ncf.npz', fqc_spider=fqc_spider, zo_spider_ncf_vals=zo_spider_ncf_vals)

# fqc_scrn, zo_scrn_vals = zo_scrn(f, num, x_0, iters_scrn, batch_g_scrn, batch_h_scrn, L, rho)



# load data
data_zo_sgd = np.load('zo_sgd_ncf.npz')
data_zo_scsg_rand = np.load('zo_scsg_ncf_rand.npz')
data_zo_scsg_coord = np.load('zo_scsg_ncf_coord.npz')
data_zo_spider = np.load('zo_spider_ncf.npz')

fqc_sgd = data_zo_sgd['fqc_sgd']
zo_sgd_ncf_vals = data_zo_sgd['zo_sgd_ncf_vals']
fqc_scsg_rand = data_zo_scsg_rand['fqc_scsg_rand']
zo_scsg_ncf_rand_vals = data_zo_scsg_rand['zo_scsg_ncf_rand_vals']
fqc_scsg_coord=data_zo_scsg_coord['fqc_scsg_coord']
zo_scsg_ncf_coord_vals=data_zo_scsg_coord['zo_scsg_ncf_coord_vals']
fqc_spider = data_zo_spider['fqc_spider']
zo_spider_ncf_vals = data_zo_spider['zo_spider_ncf_vals']

# print(len(fqc_scsg_rand))

# plot figures
plt.rcParams.update({'font.size': 14})
plt.figure(figsize=(8, 6))

plt.plot(fqc_sgd[0:50], zo_sgd_ncf_vals[0:50], label='ZO-SGD-NCF')
plt.plot(fqc_scsg_coord[0:100], zo_scsg_ncf_coord_vals[0:100], label='ZO-SCSG-NCF (Option I)')
plt.plot(fqc_scsg_rand[0:200], zo_scsg_ncf_rand_vals[0:200], label='ZO-SCSG-NCF (Option II)')
plt.plot(fqc_spider, zo_spider_ncf_vals, label='ZO-SPIDER-NCF')
# plt.plot(fqc_scrn, zo_scrn_vals, label='ZO-SCRN')


plt.xlabel('# of Function Query')
plt.ylabel('Objective Function')
plt.legend()
plt.savefig('figures/least_square_stochastic.pdf', bbox_inches='tight')
plt.show()
